
[torch.compile] Bunch of small changes needed for enabling torch.compile #3130

Open
pggPL wants to merge 5 commits into NVIDIA:main from pggPL:torch_compile_small_fixes

pggPL commented Jun 15, 2026

Collaborator

Description

Small standalone fixes extracted from a larger torch.compile branch and branched directly from main. The two main changes are independent of each other: making Userbuffers pybind11 queries compile-friendly, and freeing the quantized grad_output early for column-parallel SP. The PR also includes a custom-recipe caching fix, a split-accumulator refactor, and a CI test hook-up.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  1. Userbuffers pybind11 queries under torch.compile
  • is_fp8_ubuf() / with_cublasmp() return compile-time constants, but calling them causes a graph break when traced. At the nn.Module.forward boundary (where no UB communicator object is in hand yet) they go through get_ub_is_fp8(name, use_fp8), wrapped in torch.compiler.assume_constant_result. Only plain (str, bool) args are baked in, so guards are well-defined and don't rely on pybind-object identity.
  • In the hot forward/backward implementation paths the UB communicator is already fetched, so those call ub_obj.is_fp8_ubuf() / ub_obj.with_cublasmp() directly: no wrapper, no string concatenation, no redundant registry lookup. Eager speed is preserved.
  2. Free quantized grad_output early for column-parallel SP
  • Row-parallel SP already called clear_tensor_data(grad_output) on the gathered tensor. Column-parallel SP quantizes grad_output to a Float8TensorStorage (an internal tensor) but never freed it. Under torch.compile reduce-overhead this left live pool tensors at recording end ("Detected N tensor(s) in the cudagraph pool not tracked as outputs"). The free now covers row-SP and column-SP-FP8 (column-SP non-FP8 is a no-op view, so it's excluded).
  3. Replace fp8_recipe in LinearBwdArgs with pre-resolved split-accumulator booleans
  • LinearBwdArgs no longer carries the recipe object (which holds process-group references and is compile-unfriendly). dgrad_use_split_accumulator / wgrad_use_split_accumulator are resolved once in Linear.forward (reusing the existing get_fp8_recipe() call) and threaded through as plain booleans.
  4. Custom-recipe quantizer caching fix
  • The CustomRecipeState early-exit was missing an identity check, so quantizers were rebuilt on every forward even when the recipe was unchanged. Added if recipe_state.recipe is recipe: return.
  5. Test hook-up
  • Added test_torch_compile.py to L0_pytorch_unittest.
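
As a rough illustration of change 1, the sketch below wraps a registry lookup in a function that accepts only a string and a bool and marks it with torch.compiler.assume_constant_result. The FakeUbuf class, the _UB_REGISTRY dict, the "qkv_fprop" key, and this version of get_ub are stand-ins invented for the example (the real communicator is a pybind11 object). The fallback to a no-op decorator is also an assumption, added only so the snippet runs where PyTorch is not installed.

```python
from typing import Callable, Dict, Tuple

try:
    import torch

    # Marks the function so torch.compile treats its result as a constant.
    assume_constant_result: Callable = torch.compiler.assume_constant_result
except (ImportError, AttributeError):
    # Assumption for this sketch only: without PyTorch, use a no-op decorator.
    def assume_constant_result(fn):
        return fn


class FakeUbuf:
    """Hypothetical stand-in for the pybind11 UB communicator object."""

    def __init__(self, is_fp8: bool) -> None:
        self._is_fp8 = is_fp8

    def is_fp8_ubuf(self) -> bool:
        return self._is_fp8


# Hypothetical registry keyed by (name, use_fp8); not the real TE data structure.
_UB_REGISTRY: Dict[Tuple[str, bool], FakeUbuf] = {
    ("qkv_fprop", True): FakeUbuf(is_fp8=True),
    ("qkv_fprop", False): FakeUbuf(is_fp8=False),
}


def get_ub(name: str, use_fp8: bool) -> FakeUbuf:
    """Stand-in lookup that returns the communicator for a name."""
    return _UB_REGISTRY[(name, use_fp8)]


@assume_constant_result
def get_ub_is_fp8(name: str, use_fp8: bool) -> bool:
    """Only plain (str, bool) arguments cross this boundary, never the communicator."""
    return get_ub(name, use_fp8).is_fp8_ubuf()
```

In eager mode the decorated function behaves like an ordinary call, which is why the hot paths can still use the communicator object directly.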

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

pggPL and others added 2 commits June 15, 2026 16:40
…stants; fix SP memory leak; test suite hook-up

Wrap CommOverlapCore pybind11 methods that return compile-time constants
so torch.compile(fullgraph=True) can trace through them without graph
breaks:
- `is_fp8_ubuf()` → `ub_is_fp8()` / `get_ub_is_fp8()` in base.py;
  `_ub_is_fp8()` in gemm.py
- `with_cublasmp()` → `ub_is_cublasmp()` in base.py

All callers in linear.py, layernorm_linear.py, layernorm_mlp.py,
base.py, gemm.py, userbuffers_backward_linear.py and
userbuffers_forward_linear.py updated.

Fix quantized grad_output not being freed early for column-parallel SP
backward. Row-parallel SP already called clear_tensor_data(grad_output)
to release the gathered tensor; column-parallel SP quantizes grad_output
to Float8TensorStorage but never freed it before returning.  Under
torch.compile reduce-overhead this leaves 3 live pool tensors at
recording end and triggers "Detected 3 tensor(s) in the cudagraph pool
not tracked as outputs".  Extend the existing clear_tensor_data guard to
cover both parallel modes.
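
The extended guard can be read as a small predicate. This is only a sketch of the decision logic as the PR description states it (row-SP always, column-SP only with FP8); the function name and its three parameters are assumptions made for illustration, and the real code calls clear_tensor_data(grad_output) on actual tensors.

```python
from typing import Optional


def should_free_grad_output(
    parallel_mode: Optional[str],
    sequence_parallel: bool,
    fp8: bool,
) -> bool:
    """Return True when the processed grad_output may be freed after the wgrad GEMM."""
    if not sequence_parallel:
        # Without SP there is no gathered or re-quantized copy to release.
        return False
    if parallel_mode == "row":
        # Row-parallel SP gathers grad_output, so the gathered tensor is freed.
        return True
    if parallel_mode == "column":
        # Column-parallel SP only creates a new tensor when quantizing to FP8;
        # the non-FP8 case is a no-op view and is left alone.
        return fp8
    return False
```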

Fix custom-recipe quantizer state being re-initialised on every forward
call even when the recipe object has not changed. The existing early-exit
for CustomRecipeState was missing an identity check on the recipe object,
so any repeated call with the same recipe would bypass the early-return
and rebuild quantizers unnecessarily.  Add `if recipe_state.recipe is
recipe: return` to restore the intended caching behaviour.
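
A minimal sketch of the identity-based early return follows. The class and method names (CustomRecipeStateSketch, ModuleSketch, init_quantizers) and the build_count counter are hypothetical and exist only to make the caching behaviour observable; they are not the Transformer Engine API.

```python
class CustomRecipeStateSketch:
    """Hypothetical stand-in that remembers which recipe built its quantizers."""

    def __init__(self, recipe: object) -> None:
        self.recipe = recipe
        self.quantizers = ["input", "weight"]  # placeholder for real quantizers


class ModuleSketch:
    def __init__(self) -> None:
        self.recipe_state = None
        self.build_count = 0  # counts how often quantizers are rebuilt

    def init_quantizers(self, recipe: object) -> None:
        state = self.recipe_state
        # Identity check: the same recipe object means the cached state is valid.
        if state is not None and state.recipe is recipe:
            return
        self.recipe_state = CustomRecipeStateSketch(recipe)
        self.build_count += 1


module = ModuleSketch()
recipe = object()
module.init_quantizers(recipe)
module.init_quantizers(recipe)  # same object, so no rebuild
```

Comparing with `is` rather than `==` keeps the check cheap and avoids depending on how a mutable recipe object defines equality.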

Add test_torch_compile.py to L0_pytorch_unittest so the autocast and
existing compile tests run in CI.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…-accumulator booleans

LinearBwdArgs stored the entire FP8 recipe object so the backward could
extract fp8_gemm_dgrad.use_split_accumulator and
fp8_gemm_wgrad.use_split_accumulator at GEMM time.  Recipe objects hold
process-group references and are not serialisable as compile-time
constants, making them incompatible with torch.compile custom-op paths.

Replace fp8_recipe with two plain bool fields:
- dgrad_use_split_accumulator (default _2X_ACC_DGRAD)
- wgrad_use_split_accumulator (default _2X_ACC_WGRAD)

These are resolved once in _linear_setup_ctx and passed into the args
struct, so the backward consumes scalars instead of a live recipe object.
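
The idea of resolving the flags once and carrying only booleans can be sketched as below. LinearBwdArgsSketch and resolve_bwd_args are hypothetical names, the two default constants are placeholder values assumed for this example, and the recipe is faked with SimpleNamespace; the hasattr guards follow the review summary's description rather than the actual diff.

```python
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any, Optional

# Placeholder defaults assumed for this sketch; the real values live in the library.
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True


@dataclass(frozen=True)
class LinearBwdArgsSketch:
    """Carries plain booleans instead of a live recipe object."""

    dgrad_use_split_accumulator: bool = _2X_ACC_DGRAD
    wgrad_use_split_accumulator: bool = _2X_ACC_WGRAD


def resolve_bwd_args(recipe: Optional[Any]) -> LinearBwdArgsSketch:
    """Read the split-accumulator flags from the recipe once, at forward time."""
    dgrad, wgrad = _2X_ACC_DGRAD, _2X_ACC_WGRAD
    if recipe is not None:
        if hasattr(recipe, "fp8_gemm_dgrad"):
            dgrad = bool(recipe.fp8_gemm_dgrad.use_split_accumulator)
        if hasattr(recipe, "fp8_gemm_wgrad"):
            wgrad = bool(recipe.fp8_gemm_wgrad.use_split_accumulator)
    return LinearBwdArgsSketch(dgrad, wgrad)


# Fake recipe used only to exercise the sketch.
fake_recipe = SimpleNamespace(
    fp8_gemm_dgrad=SimpleNamespace(use_split_accumulator=False),
    fp8_gemm_wgrad=SimpleNamespace(use_split_accumulator=True),
)
args = resolve_bwd_args(fake_recipe)
```

Because the resulting struct holds only two bools, the backward pass never needs to touch the recipe or anything it references.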

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
pggPL requested a review from ksivaman as a code owner June 15, 2026 14:41

greptile-apps Bot commented Jun 15, 2026

Contributor

Greptile Summary

This PR bundles five targeted torch.compile-readiness fixes extracted from a larger branch. Each change is independent and non-breaking.

  • UB pybind11 graph-break fix: get_ub_is_fp8() wraps is_fp8_ubuf() with @torch.compiler.assume_constant_result so the module boundary calls are folded as compile-time constants; destroy_ub() now calls torch.compiler.reset() to invalidate those constants on teardown.
  • Column-parallel SP FP8 pool-tensor leak: clear_tensor_data(grad_output) is extended to cover the column-parallel SP FP8 path (in addition to row-SP) in both linear.py and layernorm_linear.py — the quantized Float8TensorStorage is freed immediately after the wgrad GEMM, after which it is no longer referenced.
  • LinearBwdArgs recipe-object removal: fp8_recipe is replaced by two plain booleans (dgrad_use_split_accumulator, wgrad_use_split_accumulator) resolved once at forward time, eliminating a compile-unfriendly process-group reference from the autograd context.
  • CustomRecipeState rebuild fix: Adds an is identity guard so quantizers are not re-built on every forward when the custom recipe object is unchanged.
  • CI hook-up: test_torch_compile.py is added to the L0 test suite; the test's qfactory is corrected to dispatch on QuantizerRole.tensor_type (the previous string-keyed dict would have raised KeyError at runtime).

Confidence Score: 5/5

All changes are well-scoped and non-breaking; the column-SP FP8 tensor free is correctly placed after the wgrad GEMM and the freed variable is not accessed again.

Each change is narrowly targeted: the split-accumulator refactor preserves the same defaults for non-FP8 and reproduces the same hasattr guards for FP8 recipes; the grad_output free only triggers after the GEMM that consumed it; the CustomRecipeState identity check uses is which is appropriate for a mutable recipe object; and torch.compiler.reset() in destroy_ub() correctly handles stale cached constants. No correctness regressions were found across the traced code paths.

No files require special attention.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/module/base.py | Adds get_ub_is_fp8() wrapper with @assume_constant_result, adds identity check to CustomRecipeState early-exit, and calls torch.compiler.reset() in destroy_ub() to invalidate cached constants after UB teardown. |
| transformer_engine/pytorch/module/linear.py | Replaces fp8_recipe in LinearBwdArgs with pre-resolved split-accumulator booleans, extends grad_output early-free to column-parallel SP FP8 (after wgrad GEMM), and switches UB fp8 queries to get_ub_is_fp8(). |
| transformer_engine/pytorch/module/layernorm_linear.py | Applies the same column-parallel SP FP8 grad_output early-free as linear.py, and switches UB fp8 queries to get_ub_is_fp8(). ctx.fp8_recipe remains for split-accumulator (not in scope of this PR). |
| transformer_engine/pytorch/module/layernorm_mlp.py | Single-line change: switches get_ub(...).is_fp8_ubuf() to get_ub_is_fp8() in the fc2_fprop path for compile-time constant folding. |
| tests/pytorch/test_torch_compile.py | Adds get_quantizer_roles() to ToyLinear module and fixes qfactory to dispatch on QuantizerRole.tensor_type; previous string keys like 'linear_input' would have caused a KeyError. |
| qa/L0_pytorch_unittest/test.sh | Adds test_torch_compile.py to the L0 CI test suite. |

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant M as Linear.forward()
    participant GSM as FP8GlobalStateManager
    participant FA as LinearFwdArgs
    participant BA as LinearBwdArgs
    participant BW as _linear_backward()

    M->>GSM: get_fp8_recipe()
    GSM-->>M: _recipe
    M->>M: resolve dgrad/wgrad split-accumulator bools
    M->>FA: fwd_args(dgrad_use_split_accumulator, wgrad_use_split_accumulator)
    FA->>BA: _linear_setup_ctx copies plain bools
    Note over FA,BA: No recipe object stored

    BW->>BW: "use_split_accumulator = bwd_args.dgrad_use_split_accumulator"
    BW->>BW: dgrad GEMM
    BW->>BW: "use_split_accumulator = bwd_args.wgrad_use_split_accumulator"
    BW->>BW: wgrad GEMM
    BW->>BW: clear_tensor_data(grad_output) [row-SP or col-SP+FP8]

Reviews (3): Last reviewed commit: "Provide explicit QuantizerRoles in torch..." | Re-trigger Greptile

Comment on lines +557 to +560
@torch.compiler.assume_constant_result
def get_ub_is_fp8(name: str, use_fp8: bool) -> bool:
    """Query is_fp8_ubuf for a named UB communicator; treated as compile-time constant."""
    return get_ub(name, use_fp8).is_fp8_ubuf()

Contributor

P2 assume_constant_result can become stale after destroy_ub() + re-init

@torch.compiler.assume_constant_result caches the return value per (name, use_fp8) argument pair for the lifetime of a compiled region. If destroy_ub() is called and UB communicators are re-initialized with different FP8 settings (e.g. in a test harness that re-creates the communicators), the cached is_fp8_ubuf() result would be silently stale until the next recompile. In production training this should not happen — UB is typically initialized once — but test suites that tear down and rebuild UB communicators between cases could observe incorrect fp8_output/fp8_grad flags without triggering a recompile.
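
The staleness concern can be illustrated with a plain-Python analogy. The sketch below uses functools.lru_cache as a stand-in for a baked-in constant; it is not how torch.compile stores constants, and every name in it (_UB_FP8, initialize_ub, cached_is_fp8, destroy_ub and its reset_cache flag) is hypothetical. It only shows why a teardown that leaves the cache alone keeps returning the old answer, and why clearing cached state on teardown avoids that.

```python
from functools import lru_cache
from typing import Dict

# Hypothetical registry: communicator name -> whether its buffer is FP8.
_UB_FP8: Dict[str, bool] = {}


def initialize_ub(name: str, is_fp8: bool) -> None:
    _UB_FP8[name] = is_fp8


@lru_cache(maxsize=None)
def cached_is_fp8(name: str) -> bool:
    """Analogy for a value frozen at first use."""
    return _UB_FP8[name]


def destroy_ub(reset_cache: bool) -> None:
    _UB_FP8.clear()
    if reset_cache:
        # Analogous in spirit to invalidating compiled state on teardown.
        cached_is_fp8.cache_clear()


initialize_ub("proj", True)
first = cached_is_fp8("proj")

destroy_ub(reset_cache=False)
initialize_ub("proj", False)
stale = cached_is_fp8("proj")  # still the old value: the cache was not cleared

destroy_ub(reset_cache=True)
initialize_ub("proj", False)
fresh = cached_is_fp8("proj")  # re-evaluated after the cache was cleared
```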

pggPL commented Jun 16, 2026

Collaborator Author

/te-ci pytorch L1

pggPL added 2 commits June 16, 2026 14:05
…t_result

get_ub_is_fp8 bakes is_fp8_ubuf() as a compile-time constant; without a
reset, destroy_ub + re-init with different FP8 settings would read stale
values until recompile. Only affects in-memory caches, not disk.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
ToyLinear now overrides get_quantizer_roles so CustomRecipeState doesn't hit
the no-roles warning, which graph-breaks under fullgraph=True. qfactory
dispatches on role.tensor_type instead of a pre-baked string key.
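
A rough sketch of dispatching on a role attribute instead of a pre-baked string key is shown below. QuantizerRoleSketch, its tensor_type values ("input", "weight", "grad_output"), and the string results returned by qfactory are all hypothetical placeholders; the real QuantizerRole type and the quantizers built by the test are not reproduced here.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class QuantizerRoleSketch:
    """Hypothetical role object; only the tensor_type field is modelled."""

    tensor_type: str


def qfactory(role: QuantizerRoleSketch) -> str:
    """Pick a quantizer by inspecting the role, not by looking up a fixed key."""
    if role.tensor_type == "input":
        return "input-quantizer"
    if role.tensor_type == "weight":
        return "weight-quantizer"
    if role.tensor_type == "grad_output":
        return "grad-quantizer"
    # Unknown roles get a default instead of raising KeyError.
    return "default-quantizer"
```

Reading the attribute at call time means an unexpected role falls through to a default rather than failing on a missing dictionary key.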

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>